from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional

import mmcv

try:
    from model_archiver.model_packaging import package_model
    from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
    package_model = None


def mmdet2torchserve(
    config_file: str,
    checkpoint_file: str,
    output_folder: str,
    model_name: Optional[str],
    model_version: str = "1.0",
    force: bool = False,
):
    """Convert an MMDetection model (config + checkpoint) to a TorchServe `.mar`.

    Args:
        config_file:
            In MMDetection config format.
            The contents vary for each task repository.
        checkpoint_file:
            In MMDetection checkpoint format.
            The contents vary for each task repository.
        output_folder:
            Folder where `{model_name}.mar` will be created (it is created
            if it does not exist yet).
            The file created will be in TorchServe archive format.
        model_name:
            If not None, used for naming the `{model_name}.mar` file
            that will be created under `output_folder`.
            If None, `{Path(checkpoint_file).stem}` will be used.
        model_version:
            Model's version.
        force:
            If True, if there is an existing `{model_name}.mar`
            file under `output_folder` it will be overwritten.

    Raises:
        ImportError: If `torch-model-archiver` is not installed.
    """
    # When torch-model-archiver is missing, the module-level import fallback
    # sets `package_model` to None and leaves `ModelExportUtils` undefined.
    # Fail fast with an actionable message before touching the filesystem,
    # instead of crashing later with a NameError.
    if package_model is None:
        raise ImportError(
            "`torch-model-archiver` is required. "
            "Try: pip install torch-model-archiver"
        )

    mmcv.mkdir_or_exist(output_folder)

    config = mmcv.Config.fromfile(config_file)

    with TemporaryDirectory() as tmpdir:
        # The loaded config is written to a standalone file that is packaged
        # as the archive's `model_file`. Packaging must happen inside this
        # `with` block because the temporary directory is removed on exit.
        config_path = f"{tmpdir}/config.py"
        config.dump(config_path)

        args = Namespace(
            model_file=config_path,
            serialized_file=checkpoint_file,
            handler=f"{Path(__file__).parent}/mmdet_handler.py",
            model_name=model_name or Path(checkpoint_file).stem,
            version=model_version,
            export_path=output_folder,
            force=force,
            requirements_file=None,
            extra_files=None,
            runtime="python",
            archive_format="default",
        )
        manifest = ModelExportUtils.generate_manifest_json(args)
        package_model(args, manifest)


def parse_args(argv=None):
    """Parse command-line arguments for the MMDetection -> TorchServe converter.

    Args:
        argv: Sequence of argument strings to parse. If None (the default),
            `ArgumentParser.parse_args` falls back to `sys.argv[1:]`, so
            existing callers that pass nothing behave as before.

    Returns:
        A `Namespace` with attributes `config`, `checkpoint`,
        `output_folder`, `model_name`, `model_version` and `force`.
    """
    parser = ArgumentParser(
        description="Convert MMDetection models to TorchServe `.mar` format."
    )
    parser.add_argument("config", type=str, help="config file path")
    parser.add_argument("checkpoint", type=str, help="checkpoint file path")
    parser.add_argument(
        "--output-folder",
        type=str,
        required=True,
        help="Folder where `{model_name}.mar` will be created.",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default=None,
        # Adjacent string literals are joined verbatim, so every piece must
        # carry its own separating space.
        help="If not None, used for naming the `{model_name}.mar` "
        "file that will be created under `output_folder`. "
        "If None, `{Path(checkpoint_file).stem}` will be used.",
    )
    parser.add_argument(
        "--model-version", type=str, default="1.0", help="Number used for versioning."
    )
    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        help="overwrite the existing `{model_name}.mar`",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()

    # Checked after argument parsing so `--help` still works without the
    # optional dependency installed.
    if package_model is None:
        raise ImportError(
            "`torch-model-archiver` is required. "
            "Try: pip install torch-model-archiver"
        )

    # Keyword arguments guard against silently swapping positional values.
    mmdet2torchserve(
        config_file=args.config,
        checkpoint_file=args.checkpoint,
        output_folder=args.output_folder,
        model_name=args.model_name,
        model_version=args.model_version,
        force=args.force,
    )
